{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the directory two levels above the kernel's working directory at the\n",
    "# front of sys.path (presumably so the project-local imports in the next cell\n",
    "# resolve - confirm).\n",
    "# The root is derived from os.getcwd() rather than sys.path[0]: sys.path[0] is\n",
    "# overwritten by the insert below, so reading it back on a re-run would climb\n",
    "# two further levels every time the cell is executed.\n",
    "# NOTE(review): assumes the kernel was started in the notebook's own directory,\n",
    "# which is what sys.path[0] resolves to on a fresh kernel - confirm.\n",
    "_project_root = os.path.dirname(os.path.dirname(os.getcwd()))\n",
    "\n",
    "# Drop an existing entry first so the root always ends up in front without\n",
    "# accumulating duplicates across re-runs.\n",
    "if _project_root in sys.path:\n",
    "    sys.path.remove(_project_root)\n",
    "sys.path.insert(0, _project_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Environment flag is set ahead of the transformers import further down.\n",
    "os.environ['TOKENIZERS_PARALLELISM'] = 'True'\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "from PIL import Image,ImageDraw\n",
    "import easyocr\n",
    "import numpy as np\n",
    "import json\n",
    "import re\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "# NOTE(review): Html2BboxTree and BboxTree2Html are imported here and again from\n",
    "# `utils` below; the later import is the binding that stays in scope.\n",
    "from scripts.train.utils import smart_tokenizer_and_embedding_resize,Html2BboxTree,move_to_device,BboxTree2Html,add_special_tokens\n",
    "# Star import: SEED (used below) and processor_name_or_path (used when loading the\n",
    "# processor) are not defined in this notebook, so they presumably come from here - confirm.\n",
    "from vars import *\n",
    "from my_dataset import UICoderDataset,UICoderCollater\n",
    "from transformers import AutoProcessor, Pix2StructForConditionalGeneration,AddedToken\n",
    "from utils import Html2BboxTree, BboxTree2StyleList, BboxTree2Html\n",
    "from datasets import Dataset\n",
    "\n",
    "# Seed torch's RNG; the generation cell samples with do_sample=True.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Device string used for model placement and for moving the inputs.\n",
    "device = 'cuda:0'\n",
    "# Checkpoint directory handed to from_pretrained.\n",
    "model_path = \"/data02/users/lz/code/UICoder/checkpoints/stage0/l1024_p1024_vu_3m*3/checkpoint-4000\"\n",
    "# Dataset directory handed to Dataset.load_from_disk; the commented-out\n",
    "# assignments are alternative locations.\n",
    "# data_path = '/data02/users/lz/code/UICoder/datasets/WebSight-format-parquet'\n",
    "data_path = '/data02/users/lz/code/UICoder/datasets/WebSight-format-parquet/arrow'\n",
    "# data_path = '/data02/starmage/datasets/cc/arrows_8-14_processed/'\n",
    "# NOTE(review): output_dir is not referenced by any other cell shown in this notebook.\n",
    "output_dir = '/data02/users/lz/code/UICoder/test_result'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Processor is loaded from `processor_name_or_path`, a name this notebook does\n",
    "# not define itself (presumably provided by the star import from `vars`).\n",
    "processor = AutoProcessor.from_pretrained(processor_name_or_path)\n",
    "\n",
    "# Load the checkpoint in float16 and place it on `device`.\n",
    "model = Pix2StructForConditionalGeneration.from_pretrained(\n",
    "    model_path,\n",
    "    is_encoder_decoder=True,\n",
    "    device_map=device,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "# Project helper, called with the model and the processor's tokenizer.\n",
    "add_special_tokens(model, processor.tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the dataset previously saved under `data_path`.\n",
    "ds = Dataset.load_from_disk(dataset_path=data_path)\n",
    "\n",
    "# Alternative loader through the project dataset class, kept for reference:\n",
    "# ds = UICoderDataset(path=data_path,processor=processor,max_length=2048,max_patches=512,max_num=100,drop_longer=True,stage=1,preprocess=True, make_patches_while_training=True, workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative input, kept for reference: open a standalone file instead.\n",
    "# image = Image.open('/data02/users/lz/code/image.jpg')\n",
    "\n",
    "# Fetch row 216 of the dataset and take its 'image' field.\n",
    "sample_record = ds[216]\n",
    "image = sample_record['image']\n",
    "\n",
    "# Print the size, then leave the image as the cell's displayed value.\n",
    "print(image.size)\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "# Decoder prefix that generation continues from. Named `prompt` so the builtin\n",
    "# input() is not rebound for the rest of the kernel session.\n",
    "prompt = '<!DOCTYPE html>'\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Tokenize the prefix and drop the last token along the final axis\n",
    "    # (presumably the end-of-sequence token added by add_special_tokens=True,\n",
    "    # so the decoder keeps generating after the prefix - confirm).\n",
    "    decoder_input_ids = processor.tokenizer.encode(\n",
    "        prompt,\n",
    "        return_tensors='pt',\n",
    "        add_special_tokens=True,\n",
    "    )[..., :-1]\n",
    "\n",
    "    # Encode the single image with an empty text entry.\n",
    "    encoding = processor(images=[image], text=[\"\"], max_patches=1024, return_tensors='pt')\n",
    "\n",
    "    item = {\n",
    "        'decoder_input_ids': decoder_input_ids,\n",
    "        # Cast to half precision; the model is loaded with torch_dtype=torch.float16.\n",
    "        'flattened_patches': encoding['flattened_patches'].half(),\n",
    "        'attention_mask': encoding['attention_mask'],\n",
    "    }\n",
    "    item = move_to_device(item, device)\n",
    "\n",
    "    # Sampling is enabled, so the result depends on the current torch RNG state.\n",
    "    outputs = model.generate(\n",
    "        **item,\n",
    "        max_new_tokens=2048,\n",
    "        eos_token_id=processor.tokenizer.eos_token_id,\n",
    "        do_sample=True,\n",
    "    )\n",
    "\n",
    "    # Decode the batch and keep the first sequence (one image was passed in).\n",
    "    prediction_html = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the decoded generation string as this cell's output.\n",
    "prediction_html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the input image again (presumably for visual comparison with the\n",
    "# generated HTML shown in the previous cell).\n",
    "image"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
